{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 网络中的网络（NiN）\n",
    "\n",
    "前几节介绍的LeNet、AlexNet和VGG在设计上的共同之处是：先以由卷积层构成的模块充分抽取空间特征，再以由全连接层构成的模块来输出分类结果。其中，AlexNet和VGG对LeNet的改进主要在于如何对这两个模块加宽（增加通道数）和加深。本节我们介绍网络中的网络（NiN）[1]。它提出了另外一个思路，即串联多个由卷积层和“全连接”层构成的小网络来构建一个深层网络。\n",
    "\n",
    "\n",
    "## NiN块\n",
    "\n",
    "我们知道，卷积层的输入和输出通常是四维数组（样本，通道，高，宽），而全连接层的输入和输出则通常是二维数组（样本，特征）。如果想在全连接层后再接上卷积层，则需要将全连接层的输出变换为四维。回忆在[“多输入通道和多输出通道”](channels.ipynb)一节里介绍的$1\\times 1$卷积层。它可以看成全连接层，其中空间维度（高和宽）上的每个元素相当于样本，通道相当于特征。因此，NiN使用$1\\times 1$卷积层来替代全连接层，从而使空间信息能够自然传递到后面的层中去。图5.7对比了NiN同AlexNet和VGG等网络在结构上的主要区别。\n",
    "\n",
    "![左图是AlexNet和VGG的网络结构局部，右图是NiN的网络结构局部](../img/nin.svg)\n",
    "\n",
    "NiN块是NiN中的基础块。它由一个卷积层加两个充当全连接层的$1\\times 1$卷积层串联而成。其中第一个卷积层的超参数可以自行设置，而第二和第三个卷积层的超参数一般是固定的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "2"
    }
   },
   "outputs": [],
   "source": [
    "import d2lzh as d2l\n",
    "from mxnet import gluon, init, nd\n",
    "from mxnet.gluon import nn\n",
    "\n",
    "def nin_block(num_channels, kernel_size, strides, padding):\n",
    "    blk = nn.Sequential()\n",
    "    blk.add(nn.Conv2D(num_channels, kernel_size,\n",
    "                      strides, padding, activation='relu'),\n",
    "            nn.Conv2D(num_channels, kernel_size=1, activation='relu'),\n",
    "            nn.Conv2D(num_channels, kernel_size=1, activation='relu'))\n",
    "    return blk"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NiN模型\n",
    "\n",
    "NiN是在AlexNet问世不久后提出的。它们的卷积层设定有类似之处。NiN使用卷积窗口形状分别为$11\\times 11$、$5\\times 5$和$3\\times 3$的卷积层，相应的输出通道数也与AlexNet中的一致。每个NiN块后接一个步幅为2、窗口形状为$3\\times 3$的最大池化层。\n",
    "\n",
    "除使用NiN块以外，NiN还有一个设计与AlexNet显著不同：NiN去掉了AlexNet最后的3个全连接层，取而代之地，NiN使用了输出通道数等于标签类别数的NiN块，然后使用全局平均池化层对每个通道中所有元素求平均并直接用于分类。这里的全局平均池化层即窗口形状等于输入空间维形状的平均池化层。NiN的这个设计的好处是可以显著减小模型参数尺寸，从而缓解过拟合。然而，该设计有时会造成获得有效模型的训练时间的增加。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "attributes": {
     "classes": [],
     "id": "",
     "n": "9"
    }
   },
   "outputs": [],
   "source": [
    "net = nn.Sequential()\n",
    "net.add(nin_block(96, kernel_size=11, strides=4, padding=0),\n",
    "        nn.MaxPool2D(pool_size=3, strides=2),\n",
    "        nin_block(256, kernel_size=5, strides=1, padding=2),\n",
    "        nn.MaxPool2D(pool_size=3, strides=2),\n",
    "        nin_block(384, kernel_size=3, strides=1, padding=1),\n",
    "        nn.MaxPool2D(pool_size=3, strides=2), nn.Dropout(0.5),\n",
    "        # 标签类别数是10\n",
    "        nin_block(10, kernel_size=3, strides=1, padding=1),\n",
    "        # 全局平均池化层将窗口形状自动设置成输入的高和宽\n",
    "        nn.GlobalAvgPool2D(),\n",
    "        # 将四维的输出转成二维的输出，其形状为(批量大小, 10)\n",
    "        nn.Flatten())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们构建一个数据样本来查看每一层的输出形状。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sequential1 output shape:\t (1, 96, 54, 54)\n",
      "pool0 output shape:\t (1, 96, 26, 26)\n",
      "sequential2 output shape:\t (1, 256, 26, 26)\n",
      "pool1 output shape:\t (1, 256, 12, 12)\n",
      "sequential3 output shape:\t (1, 384, 12, 12)\n",
      "pool2 output shape:\t (1, 384, 5, 5)\n",
      "dropout0 output shape:\t (1, 384, 5, 5)\n",
      "sequential4 output shape:\t (1, 10, 5, 5)\n",
      "pool3 output shape:\t (1, 10, 1, 1)\n",
      "flatten0 output shape:\t (1, 10)\n"
     ]
    }
   ],
   "source": [
    "X = nd.random.uniform(shape=(1, 1, 224, 224))\n",
    "net.initialize()\n",
    "for layer in net:\n",
    "    X = layer(X)\n",
    "    print(layer.name, 'output shape:\\t', X.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练模型\n",
    "\n",
    "我们依然使用Fashion-MNIST数据集来训练模型。NiN的训练与AlexNet和VGG的类似，但这里使用的学习率更大。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on gpu(0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, loss 2.2705, train acc 0.150, test acc 0.255, time 24.1 sec\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 2, loss 1.4471, train acc 0.480, test acc 0.670, time 22.9 sec\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 3, loss 0.9368, train acc 0.666, test acc 0.705, time 22.9 sec\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 4, loss 0.8007, train acc 0.707, test acc 0.742, time 22.8 sec\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 5, loss 0.7465, train acc 0.725, test acc 0.753, time 22.8 sec\n"
     ]
    }
   ],
   "source": [
    "lr, num_epochs, batch_size, ctx = 0.1, 5, 128, d2l.try_gpu()\n",
    "net.initialize(force_reinit=True, ctx=ctx, init=init.Xavier())\n",
    "trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, trainer, ctx,\n",
    "              num_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* NiN重复使用由卷积层和代替全连接层的$1\\times 1$卷积层构成的NiN块来构建深层网络。\n",
    "* NiN去除了容易造成过拟合的全连接输出层，而是将其替换成输出通道数等于标签类别数的NiN块和全局平均池化层。\n",
    "* NiN的以上设计思想影响了后面一系列卷积神经网络的设计。\n",
    "\n",
    "## 练习\n",
    "\n",
    "* 调节超参数，提高分类准确率。\n",
    "* 为什么NiN块里要有两个$1\\times 1$卷积层？去除其中的一个，观察并分析实验现象。\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "## 参考文献\n",
    "\n",
    "[1] Lin, M., Chen, Q., & Yan, S. (2013). Network in network. arXiv preprint arXiv:1312.4400.\n",
    "\n",
    "## 扫码直达[讨论区](https://discuss.gluon.ai/t/topic/1661)\n",
    "\n",
    "![](../img/qr_nin.svg)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}